package com.alibaba.datax.plugin.rdbms.writer;

import com.alibaba.datax.common.element.Column;
import com.alibaba.datax.common.element.Record;
import com.alibaba.datax.common.exception.DataXException;
import com.alibaba.datax.common.plugin.RecordReceiver;
import com.alibaba.datax.common.plugin.TaskPluginCollector;
import com.alibaba.datax.common.util.Configuration;
import com.alibaba.datax.plugin.rdbms.util.DBUtil;
import com.alibaba.datax.plugin.rdbms.util.DBUtilErrorCode;
import com.alibaba.datax.plugin.rdbms.util.DataBaseType;
import com.alibaba.datax.plugin.rdbms.util.RdbmsException;
import com.alibaba.datax.plugin.rdbms.writer.util.OriginalConfPretreatmentUtil;
import com.alibaba.datax.plugin.rdbms.writer.util.WriterUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Triple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;

public class CommonRdbmsWriter {

  public static class Job {

    // Database dialect this job writes to; assigned once in the constructor.
    private DataBaseType dataBaseType;

    // Logger shared by all Job instances.
    private static final Logger LOG = LoggerFactory.getLogger(Job.class);

    /**
     * Creates a job bound to the given database type.
     *
     * <p>The type is also stored in the static field
     * {@code OriginalConfPretreatmentUtil.DATABASE_TYPE}.
     *
     * @param dataBaseType database dialect this job writes to
     */
    public Job(DataBaseType dataBaseType) {
      OriginalConfPretreatmentUtil.DATABASE_TYPE = dataBaseType;
      this.dataBaseType = dataBaseType;
    }

    /**
     * Runs the configuration pre-treatment for this job's database type and logs the resulting
     * configuration at debug level.
     *
     * @param originalConfig job configuration handed to
     *     {@code OriginalConfPretreatmentUtil.doPretreatment}
     */
    public void init(Configuration originalConfig) {
      OriginalConfPretreatmentUtil.doPretreatment(originalConfig, dataBaseType);

      LOG.debug(
          "After job init(), originalConfig now is:[\n{}\n]",
          originalConfig.toJSON());
    }

    /* Currently only supported by the MySQL and Oracle writers: checks the preSql/postSql syntax and the insert/delete privileges. */
    /**
     * Runs the writer pre-flight checks: first the preSql/postSql validation, then the
     * insert/delete privilege validation.
     *
     * @param originalConfig job configuration to validate
     * @param dataBaseType database dialect passed through to both checks
     */
    public void writerPreCheck(Configuration originalConfig, DataBaseType dataBaseType) {
      // Step 1: validate the preSql and postSql statements.
      this.prePostSqlValid(originalConfig, dataBaseType);
      // Step 2: verify the insert and delete privileges.
      this.privilegeValid(originalConfig, dataBaseType);
    }

    /**
     * Validates the configured preSql and postSql statements by delegating to
     * {@code WriterUtil.preCheckPrePareSQL} and then {@code WriterUtil.preCheckPostSQL}.
     *
     * @param originalConfig job configuration holding the statements
     * @param dataBaseType database dialect passed to both checks
     */
    public void prePostSqlValid(Configuration originalConfig, DataBaseType dataBaseType) {
      // preSql statements are checked first ...
      WriterUtil.preCheckPrePareSQL(originalConfig, dataBaseType);
      // ... followed by the postSql statements.
      WriterUtil.preCheckPostSQL(originalConfig, dataBaseType);
    }

    /**
     * Verifies, for every configured connection, that the configured user may insert into the
     * connection's tables and, when {@code DBUtil.needCheckDeletePrivilege} says so, delete from
     * them as well.
     *
     * @param originalConfig job configuration providing user name, password and connection list
     * @param dataBaseType database dialect used for the privilege probes
     * @throws RuntimeException the exception built by {@code RdbmsException.asInsertPriException}
     *     or {@code RdbmsException.asDeletePriException} when a privilege is missing
     */
    public void privilegeValid(Configuration originalConfig, DataBaseType dataBaseType) {
      final String username = originalConfig.getString(Key.USERNAME);
      final String password = originalConfig.getString(Key.PASSWORD);
      final List<Object> connections = originalConfig.getList(Constant.CONN_MARK, Object.class);

      for (Object rawConnection : connections) {
        Configuration connConf = Configuration.from(rawConnection.toString());
        String jdbcUrl = connConf.getString(Key.JDBC_URL);
        List<String> expandedTables = connConf.getList(Key.TABLE, String.class);

        // The insert privilege is always required.
        boolean canInsert = DBUtil
            .checkInsertPrivilege(dataBaseType, jdbcUrl, username, password, expandedTables);
        if (!canInsert) {
          throw RdbmsException.asInsertPriException(dataBaseType, username, jdbcUrl);
        }

        // The delete privilege is only probed when the configuration calls for it.
        if (!DBUtil.needCheckDeletePrivilege(originalConfig)) {
          continue;
        }
        boolean canDelete = DBUtil
            .checkDeletePrivilege(dataBaseType, jdbcUrl, username, password, expandedTables);
        if (!canDelete) {
          throw RdbmsException.asDeletePriException(dataBaseType, username, jdbcUrl);
        }
      }
    }

    // In general, running the pre statements is deferred to the task (the single-table case is the exception).
    /**
     * Prepares the job-level configuration and, for single-table jobs, runs the preSql statements.
     *
     * <p>When {@code Constant.TABLE_NUMBER_MARK} is 1, the first connection's JDBC URL and first
     * table are copied to the top level of {@code originalConfig}, the connection list is removed,
     * and any rendered preSql statements are executed here (the {@code Key.PRE_SQL} entry is
     * removed from the configuration first). For any other table count nothing is executed here.
     *
     * <p>The connection is released in a {@code finally} block so it is not leaked when a preSql
     * statement fails.
     *
     * @param originalConfig job configuration; modified in place
     */
    public void prepare(Configuration originalConfig) {
      int tableNumber = originalConfig.getInt(Constant.TABLE_NUMBER_MARK);
      if (tableNumber == 1) {
        String username = originalConfig.getString(Key.USERNAME);
        String password = originalConfig.getString(Key.PASSWORD);

        List<Object> conns = originalConfig.getList(Constant.CONN_MARK, Object.class);
        Configuration connConf = Configuration.from(conns.get(0).toString());

        // The jdbcUrl read here already carries the appropriate suffix parameters.
        String jdbcUrl = connConf.getString(Key.JDBC_URL);
        originalConfig.set(Key.JDBC_URL, jdbcUrl);

        String table = connConf.getList(Key.TABLE, String.class).get(0);
        originalConfig.set(Key.TABLE, table);

        List<String> preSqls = originalConfig.getList(Key.PRE_SQL, String.class);
        List<String> renderedPreSqls = WriterUtil.renderPreOrPostSqls(preSqls, table);

        originalConfig.remove(Constant.CONN_MARK);
        if (null != renderedPreSqls && !renderedPreSqls.isEmpty()) {
          // preSql is configured and is executed here, so drop it from the configuration.
          originalConfig.remove(Key.PRE_SQL);

          Connection conn = DBUtil.getConnection(dataBaseType, jdbcUrl, username, password);
          try {
            LOG.info("Begin to execute preSqls:[{}]. context info:{}.",
                StringUtils.join(renderedPreSqls, ";"),
                jdbcUrl);

            WriterUtil.executeSqls(conn, renderedPreSqls, jdbcUrl, dataBaseType);
          } finally {
            // Always release the connection, even when a statement fails.
            DBUtil.closeDBResources(null, null, conn);
          }
        }
      }

      LOG.debug("After job prepare(), originalConfig now is:[\n{}\n]", originalConfig.toJSON());
    }

    /**
     * Splits the job configuration into task configurations by delegating to
     * {@code WriterUtil.doSplit}.
     *
     * @param originalConfig job configuration to split
     * @param mandatoryNumber slice count passed through to {@code WriterUtil.doSplit}
     * @return the configurations produced by {@code WriterUtil.doSplit}
     */
    public List<Configuration> split(Configuration originalConfig, int mandatoryNumber) {
      List<Configuration> sliceConfigs = WriterUtil.doSplit(originalConfig, mandatoryNumber);
      return sliceConfigs;
    }

    // In general, running the post statements is deferred to the task (the single-table case is the exception).
    /**
     * Runs the postSql statements at job level for single-table jobs.
     *
     * <p>When {@code Constant.TABLE_NUMBER_MARK} is 1, the postSql statements are rendered for the
     * top-level table and, if any exist, executed on a fresh connection (the {@code Key.POST_SQL}
     * entry is removed from the configuration first). For any other table count this method does
     * nothing.
     *
     * <p>The connection is released in a {@code finally} block so it is not leaked when a postSql
     * statement fails.
     *
     * @param originalConfig job configuration; {@code Key.JDBC_URL} and {@code Key.TABLE} are read
     *     from its top level
     */
    public void post(Configuration originalConfig) {
      int tableNumber = originalConfig.getInt(Constant.TABLE_NUMBER_MARK);
      if (tableNumber == 1) {
        String username = originalConfig.getString(Key.USERNAME);
        String password = originalConfig.getString(Key.PASSWORD);

        // The JDBC suffix was already appended by the time prepare() stored this URL.
        String jdbcUrl = originalConfig.getString(Key.JDBC_URL);

        String table = originalConfig.getString(Key.TABLE);

        List<String> postSqls = originalConfig.getList(Key.POST_SQL, String.class);
        List<String> renderedPostSqls = WriterUtil.renderPreOrPostSqls(postSqls, table);

        if (null != renderedPostSqls && !renderedPostSqls.isEmpty()) {
          // postSql is configured and is executed here, so drop it from the configuration.
          originalConfig.remove(Key.POST_SQL);

          Connection conn = DBUtil.getConnection(this.dataBaseType, jdbcUrl, username, password);
          try {
            LOG.info("Begin to execute postSqls:[{}]. context info:{}.",
                StringUtils.join(renderedPostSqls, ";"), jdbcUrl);
            WriterUtil.executeSqls(conn, renderedPostSqls, jdbcUrl, dataBaseType);
          } finally {
            // Always release the connection, even when a statement fails.
            DBUtil.closeDBResources(null, null, conn);
          }
        }
      }
    }

    /**
     * Job-level teardown hook; intentionally empty in this base implementation.
     *
     * @param originalConfig job configuration (unused here)
     */
    public void destroy(Configuration originalConfig) {
    }

  }

  public static class Task {

    // Logger shared by Task and its subclasses.
    protected static final Logger LOG = LoggerFactory.getLogger(Task.class);

    // Database dialect this task writes to; assigned in the constructor.
    protected DataBaseType dataBaseType;

    // Placeholder returned by the default calcValueHolder() implementation.
    private static final String VALUE_HOLDER = "?";

    // Connection credentials and URL, read from the slice configuration in init().
    protected String username;

    protected String password;

    protected String jdbcUrl;

    // Target table name, read in init().
    protected String table;

    // Target column names, read in init().
    protected List<String> columns;

    // Statements configured under Key.PRE_SQL / Key.POST_SQL, read in init().
    protected List<String> preSqls;

    protected List<String> postSqls;

    // Flush threshold: maximum number of buffered records per batch.
    protected int batchSize;

    // Flush threshold: maximum accumulated Record.getMemorySize() per batch.
    protected int batchByteSize;

    // Number of entries in columns; set in init().
    protected int columnNumber = 0;

    // Receives records that fail in doOneInsert(); set in startWriteWithConnection().
    protected TaskPluginCollector taskPluginCollector;

    // Common context attached to log messages, e.g. which database connection and which table
    // an operation targets. NOTE: static, so every init() call overwrites it for all instances.
    protected static String BASIC_MESSAGE;

    // Write-statement template with a placeholder for the table name (applied via String.format).
    // NOTE: static, so every init()/calcWriteRecordSql() call overwrites it for all instances.
    protected static String INSERT_OR_REPLACE_TEMPLATE;

    // Final SQL used to write one record: the template formatted with the table name.
    protected String writeRecordSql;

    // Value of Key.WRITE_MODE; defaults to "INSERT" in init().
    protected String writeMode;

    // When true, an empty string bound to a numeric column is written as NULL.
    protected boolean emptyAsNull;

    // Target column metadata from DBUtil.getColumnMetaData(). As used in this class: left is
    // reported as the column name, middle is the java.sql.Types code, right is the type name.
    protected Triple<List<String>, List<Integer>, List<String>> resultSetMetaData;

    /**
     * Creates a task bound to the given database type.
     *
     * @param dataBaseType database dialect this task writes to
     */
    public Task(DataBaseType dataBaseType) {
      this.dataBaseType = dataBaseType;
    }

    /**
     * Loads this task's settings from its slice configuration.
     *
     * <p>Reads credentials, JDBC URL, table, columns, pre/post SQL, batch limits, write mode and
     * the empty-as-null flag, then derives {@code writeRecordSql} from the configured template.
     * For a MySql task whose URL starts with {@code Constant.OB10_SPLIT_STRING}, the URL is split
     * into three parts: the second is prefixed to the user name and the third becomes the URL.
     *
     * <p>NOTE: {@code INSERT_OR_REPLACE_TEMPLATE} and {@code BASIC_MESSAGE} are static fields, so
     * each call overwrites the values seen by every {@code Task} instance.
     *
     * @param writerSliceConfig configuration of this task's slice
     * @throws DataXException with {@code JDBC_OB10_ADDRESS_ERROR} when an OB 1.0 URL does not
     *     split into exactly three parts
     */
    public void init(Configuration writerSliceConfig) {
      final Configuration conf = writerSliceConfig;

      this.username = conf.getString(Key.USERNAME);
      this.password = conf.getString(Key.PASSWORD);
      this.jdbcUrl = conf.getString(Key.JDBC_URL);

      // OB 1.0 handling: the URL carries extra segments that must be unpacked.
      boolean ob10Url = this.jdbcUrl.startsWith(Constant.OB10_SPLIT_STRING)
          && DataBaseType.MySql == this.dataBaseType;
      if (ob10Url) {
        String[] segments = this.jdbcUrl.split(Constant.OB10_SPLIT_STRING_PATTERN);
        if (3 != segments.length) {
          throw DataXException.asDataXException(
              DBUtilErrorCode.JDBC_OB10_ADDRESS_ERROR,
              "JDBC OB10格式错误，请联系askdatax");
        }
        LOG.info("this is ob1_0 jdbc url.");
        this.username = segments[1].trim() + ":" + this.username;
        this.jdbcUrl = segments[2];
        LOG.info("this is ob1_0 jdbc url. user=" + this.username + " :url=" + this.jdbcUrl);
      }

      this.table = conf.getString(Key.TABLE);

      this.columns = conf.getList(Key.COLUMN, String.class);
      this.columnNumber = this.columns.size();

      this.preSqls = conf.getList(Key.PRE_SQL, String.class);
      this.postSqls = conf.getList(Key.POST_SQL, String.class);
      this.batchSize = conf.getInt(Key.BATCH_SIZE, Constant.DEFAULT_BATCH_SIZE);
      this.batchByteSize = conf.getInt(Key.BATCH_BYTE_SIZE, Constant.DEFAULT_BATCH_BYTE_SIZE);

      this.writeMode = conf.getString(Key.WRITE_MODE, "INSERT");
      this.emptyAsNull = conf.getBool(Key.EMPTY_AS_NULL, true);

      // The template has a slot for the table name; fill it to get the final statement.
      INSERT_OR_REPLACE_TEMPLATE = conf.getString(Constant.INSERT_OR_REPLACE_TEMPLATE_MARK);
      this.writeRecordSql = String.format(INSERT_OR_REPLACE_TEMPLATE, this.table);

      BASIC_MESSAGE = String.format("jdbcUrl:[%s], table:[%s]", this.jdbcUrl, this.table);
    }

    /**
     * Task-level prepare step: opens a connection, applies the session configuration and, unless
     * the job has exactly one table, executes the configured preSql statements.
     *
     * <p>The connection is released in a {@code finally} block so it is not leaked when the
     * session setup or a preSql statement fails.
     *
     * @param writerSliceConfig configuration of this task's slice
     */
    public void prepare(Configuration writerSliceConfig) {
      Connection connection = DBUtil
          .getConnection(this.dataBaseType, this.jdbcUrl, username, password);
      try {
        DBUtil.dealWithSessionConfig(connection, writerSliceConfig, this.dataBaseType,
            BASIC_MESSAGE);

        int tableNumber = writerSliceConfig.getInt(Constant.TABLE_NUMBER_MARK);
        // With exactly one table the task skips preSql (Job.prepare covers that case).
        if (tableNumber != 1) {
          LOG.info("Begin to execute preSqls:[{}]. context info:{}.",
              StringUtils.join(this.preSqls, ";"),
              BASIC_MESSAGE);
          WriterUtil.executeSqls(connection, this.preSqls, BASIC_MESSAGE, dataBaseType);
        }
      } finally {
        // Always release the connection, even when a statement fails.
        DBUtil.closeDBResources(null, null, connection);
      }
    }

    /**
     * Drains {@code recordReceiver} and writes the records to the target table in batches over
     * the supplied connection.
     *
     * <p>A batch is flushed through {@link #doBatchInsert} once it holds {@code batchSize} records
     * or its accumulated {@code Record.getMemorySize()} reaches {@code batchByteSize}; the
     * remaining records are flushed when the receiver is exhausted.
     *
     * <p>This method takes ownership of {@code connection}: it is closed on every exit path,
     * including failures while fetching column metadata or building the write statement.
     *
     * @param recordReceiver source of records; a {@code null} record ends the stream
     * @param taskPluginCollector collector stored on this task for dirty records
     * @param connection open connection to the target database; closed before returning
     * @throws DataXException with {@code CONF_ERROR} when a record's column count differs from
     *     the configured column count, or the result of
     *     {@code DataXException.asDataXException(WRITE_DATA_ERROR, e)} for any other write failure
     */
    public void startWriteWithConnection(RecordReceiver recordReceiver,
        TaskPluginCollector taskPluginCollector,
        Connection connection) {
      this.taskPluginCollector = taskPluginCollector;

      try {
        // Target column metadata drives the per-column type conversion when writing.
        this.resultSetMetaData = DBUtil
            .getColumnMetaData(connection, this.table, StringUtils.join(this.columns, ","));
        // Build the SQL statement used to write each record.
        calcWriteRecordSql();

        List<Record> writeBuffer = new ArrayList<Record>(this.batchSize);
        int bufferBytes = 0;
        try {
          Record record;
          while ((record = recordReceiver.getFromReader()) != null) {
            if (record.getColumnNumber() != this.columnNumber) {
              // Source column count differs from the target column count: fail immediately.
              throw DataXException.asDataXException(DBUtilErrorCode.CONF_ERROR,
                  String.format("列配置信息有错误. 因为您配置的任务中，源头读取字段数:%s 与 目的表要写入的字段数:%s 不相等. 请检查您的配置并作出修改.",
                      record.getColumnNumber(), this.columnNumber));
            }

            writeBuffer.add(record);
            bufferBytes += record.getMemorySize();

            if (writeBuffer.size() >= batchSize || bufferBytes >= batchByteSize) {
              doBatchInsert(connection, writeBuffer);
              writeBuffer.clear();
              bufferBytes = 0;
            }
          }
          // Flush whatever is left once the receiver is exhausted.
          if (!writeBuffer.isEmpty()) {
            doBatchInsert(connection, writeBuffer);
          }
        } catch (Exception e) {
          throw DataXException.asDataXException(DBUtilErrorCode.WRITE_DATA_ERROR, e);
        } finally {
          writeBuffer.clear();
        }
      } finally {
        // Covers metadata/statement setup failures as well as write failures.
        DBUtil.closeDBResources(null, null, connection);
      }
    }

    // TODO: switch to a connection pool so that every connection obtained is usable (note: a connection may need its session initialized each time).
    /**
     * Opens a connection, applies the session configuration and writes all records from
     * {@code recordReceiver} via {@link #startWriteWithConnection}.
     *
     * <p>If the session configuration fails, the connection is closed here; once
     * {@code startWriteWithConnection} is entered, that method closes the connection itself.
     *
     * @param recordReceiver source of records to write
     * @param writerSliceConfig configuration of this task's slice
     * @param taskPluginCollector collector for dirty records
     */
    public void startWrite(RecordReceiver recordReceiver, Configuration writerSliceConfig,
        TaskPluginCollector taskPluginCollector) {
      Connection connection = DBUtil
          .getConnection(this.dataBaseType, this.jdbcUrl, username, password);

      boolean handedOff = false;
      try {
        DBUtil.dealWithSessionConfig(connection, writerSliceConfig, this.dataBaseType,
            BASIC_MESSAGE);
        // From here on startWriteWithConnection owns (and closes) the connection.
        handedOff = true;
        startWriteWithConnection(recordReceiver, taskPluginCollector, connection);
      } finally {
        if (!handedOff) {
          // Session setup failed before hand-off: release the connection ourselves.
          DBUtil.closeDBResources(null, null, connection);
        }
      }
    }


    /**
     * Task-level post step: executes the configured postSql statements on a fresh connection.
     *
     * <p>Nothing is executed when the job has exactly one table or when no postSql statements are
     * configured. The connection is released in a {@code finally} block so it is not leaked when
     * a statement fails.
     *
     * @param writerSliceConfig configuration of this task's slice
     */
    public void post(Configuration writerSliceConfig) {
      int tableNumber = writerSliceConfig.getInt(Constant.TABLE_NUMBER_MARK);

      boolean hasPostSql = (this.postSqls != null && !this.postSqls.isEmpty());
      if (tableNumber == 1 || !hasPostSql) {
        return;
      }

      Connection connection = DBUtil
          .getConnection(this.dataBaseType, this.jdbcUrl, username, password);
      try {
        LOG.info("Begin to execute postSqls:[{}]. context info:{}.",
            StringUtils.join(this.postSqls, ";"),
            BASIC_MESSAGE);
        WriterUtil.executeSqls(connection, this.postSqls, BASIC_MESSAGE, dataBaseType);
      } finally {
        // Always release the connection, even when a statement fails.
        DBUtil.closeDBResources(null, null, connection);
      }
    }

    /**
     * Task-level teardown hook; intentionally empty in this base implementation.
     *
     * @param writerSliceConfig configuration of this task's slice (unused here)
     */
    public void destroy(Configuration writerSliceConfig) {
    }

    /**
     * Writes the buffered records as one JDBC batch inside a single transaction.
     *
     * <p>If the batch fails with a {@link SQLException}, the transaction is rolled back and the
     * same records are retried one at a time through {@link #doOneInsert}.
     *
     * @param connection open connection to the target database
     * @param buffer records to write
     * @throws SQLException if the rollback itself fails
     * @throws DataXException the result of
     *     {@code DataXException.asDataXException(WRITE_DATA_ERROR, e)} for any non-SQL failure
     */
    protected void doBatchInsert(Connection connection, List<Record> buffer) throws SQLException {
      PreparedStatement batchStatement = null;
      try {
        // One transaction per batch.
        connection.setAutoCommit(false);
        batchStatement = connection.prepareStatement(this.writeRecordSql);

        for (Record row : buffer) {
          batchStatement = fillPreparedStatement(batchStatement, row);
          batchStatement.addBatch();
        }
        batchStatement.executeBatch();
        connection.commit();
      } catch (SQLException batchFailure) {
        String reason = batchFailure.getMessage();
        LOG.warn("回滚此次写入, 采用每次写入一行方式提交. 因为:" + reason);
        connection.rollback();
        // Fall back to row-by-row writes for this buffer.
        doOneInsert(connection, buffer);
      } catch (Exception unexpected) {
        throw DataXException.asDataXException(DBUtilErrorCode.WRITE_DATA_ERROR, unexpected);
      } finally {
        DBUtil.closeDBResources(batchStatement, null);
      }
    }

    /**
     * Writes the buffered records one statement at a time with auto-commit enabled.
     *
     * <p>A record whose insert fails with a {@link SQLException} is reported to the task plugin
     * collector as a dirty record and the loop continues with the next record.
     *
     * @param connection open connection to the target database
     * @param buffer records to write
     * @throws DataXException the result of
     *     {@code DataXException.asDataXException(WRITE_DATA_ERROR, e)} for any failure outside
     *     the per-record insert (e.g. preparing the statement or clearing its parameters)
     */
    protected void doOneInsert(Connection connection, List<Record> buffer) {
      PreparedStatement rowStatement = null;
      try {
        connection.setAutoCommit(true);
        rowStatement = connection.prepareStatement(this.writeRecordSql);

        for (Record row : buffer) {
          try {
            rowStatement = fillPreparedStatement(rowStatement, row);
            rowStatement.execute();
          } catch (SQLException rowFailure) {
            String description = rowFailure.toString();
            LOG.debug(description);

            this.taskPluginCollector.collectDirtyRecord(row, rowFailure);
          } finally {
            // Reset the bound parameters before the statement is reused for the next record.
            rowStatement.clearParameters();
          }
        }
      } catch (Exception unexpected) {
        throw DataXException.asDataXException(DBUtilErrorCode.WRITE_DATA_ERROR, unexpected);
      } finally {
        DBUtil.closeDBResources(rowStatement, null);
      }
    }

    // Uses two fields of this class directly: columnNumber and resultSetMetaData.
    /**
     * Binds every column of {@code record} to the statement, using the SQL type recorded for the
     * column at the same position in {@code resultSetMetaData}.
     *
     * @param preparedStatement statement whose parameters are bound
     * @param record record supplying one value per configured column
     * @return the statement returned by the last {@link #fillPreparedStatementColumnType} call, or
     *     the given statement when there are no columns
     * @throws SQLException if binding a value fails
     */
    protected PreparedStatement fillPreparedStatement(PreparedStatement preparedStatement,
        Record record)
        throws SQLException {
      PreparedStatement bound = preparedStatement;
      int position = 0;
      while (position < this.columnNumber) {
        int sqlType = this.resultSetMetaData.getMiddle().get(position);
        Column value = record.getColumn(position);
        bound = fillPreparedStatementColumnType(bound, position, sqlType, value);
        position++;
      }
      return bound;
    }

    /**
     * Binds a single column value to the statement according to the target column's SQL type.
     *
     * <p>Character, numeric, TINYINT and BOOLEAN values are bound as strings; DATE, TIME and
     * TIMESTAMP values are converted from {@code column.asDate()}; binary types use
     * {@code column.asBytes()}. BIT is bound as a boolean on MySql and as a string elsewhere.
     *
     * <p>Changes from the earlier behavior: date/time conversion failures now keep the original
     * exception as the {@link SQLException} cause, and a {@code null} BIT value on MySql is bound
     * with {@code setNull} instead of failing on unboxing.
     *
     * @param preparedStatement statement to bind into
     * @param columnIndex zero-based column position (the JDBC parameter index is this plus one)
     * @param columnSqlType {@link java.sql.Types} code of the target column
     * @param column source value
     * @return the same statement, for chaining
     * @throws SQLException if the value cannot be converted to a date/time type or the driver
     *     rejects the binding
     * @throws DataXException with {@code UNSUPPORTED_TYPE} for SQL types not handled here
     */
    protected PreparedStatement fillPreparedStatementColumnType(PreparedStatement preparedStatement,
        int columnIndex, int columnSqlType, Column column) throws SQLException {
      // JDBC parameters are 1-based, column positions are 0-based.
      final int paramIndex = columnIndex + 1;
      java.util.Date utilDate;
      switch (columnSqlType) {
        case Types.CHAR:
        case Types.NCHAR:
        case Types.CLOB:
        case Types.NCLOB:
        case Types.VARCHAR:
        case Types.LONGVARCHAR:
        case Types.NVARCHAR:
        case Types.LONGNVARCHAR:
          preparedStatement.setString(paramIndex, column.asString());
          break;

        case Types.SMALLINT:
        case Types.INTEGER:
        case Types.BIGINT:
        case Types.NUMERIC:
        case Types.DECIMAL:
        case Types.FLOAT:
        case Types.REAL:
        case Types.DOUBLE:
          String strValue = column.asString();
          // With emptyAsNull set, an empty string is written to a numeric column as NULL.
          if (emptyAsNull && "".equals(strValue)) {
            preparedStatement.setString(paramIndex, null);
          } else {
            preparedStatement.setString(paramIndex, strValue);
          }
          break;

        // tinyint is a little special in some databases like mysql {boolean->tinyint(1)}
        case Types.TINYINT:
          Long longValue = column.asLong();
          if (null == longValue) {
            preparedStatement.setString(paramIndex, null);
          } else {
            preparedStatement.setString(paramIndex, longValue.toString());
          }
          break;

        // for mysql bug, see http://bugs.mysql.com/bug.php?id=35115
        case Types.DATE:
          if (this.resultSetMetaData.getRight().get(columnIndex).equalsIgnoreCase("year")) {
            // A "year" column reported as DATE is bound as an integer.
            java.math.BigInteger yearValue = column.asBigInteger();
            if (yearValue == null) {
              preparedStatement.setString(paramIndex, null);
            } else {
              preparedStatement.setInt(paramIndex, yearValue.intValue());
            }
          } else {
            java.sql.Date sqlDate = null;
            try {
              utilDate = column.asDate();
            } catch (DataXException e) {
              // Keep the conversion failure as the cause.
              throw new SQLException(String.format("Date 类型转换错误：[%s]", column), e);
            }

            if (null != utilDate) {
              sqlDate = new java.sql.Date(utilDate.getTime());
            }
            preparedStatement.setDate(paramIndex, sqlDate);
          }
          break;

        case Types.TIME:
          java.sql.Time sqlTime = null;
          try {
            utilDate = column.asDate();
          } catch (DataXException e) {
            throw new SQLException(String.format("TIME 类型转换错误：[%s]", column), e);
          }

          if (null != utilDate) {
            sqlTime = new java.sql.Time(utilDate.getTime());
          }
          preparedStatement.setTime(paramIndex, sqlTime);
          break;

        case Types.TIMESTAMP:
          java.sql.Timestamp sqlTimestamp = null;
          try {
            utilDate = column.asDate();
          } catch (DataXException e) {
            throw new SQLException(String.format("TIMESTAMP 类型转换错误：[%s]", column), e);
          }

          if (null != utilDate) {
            sqlTimestamp = new java.sql.Timestamp(utilDate.getTime());
          }
          preparedStatement.setTimestamp(paramIndex, sqlTimestamp);
          break;

        case Types.BINARY:
        case Types.VARBINARY:
        case Types.BLOB:
        case Types.LONGVARBINARY:
          preparedStatement.setBytes(paramIndex, column.asBytes());
          break;

        case Types.BOOLEAN:
          preparedStatement.setString(paramIndex, column.asString());
          break;

        // warn: bit(1) -> Types.BIT, setBoolean can be used
        // warn: bit(>1) -> Types.VARBINARY, setBytes can be used
        case Types.BIT:
          if (this.dataBaseType == DataBaseType.MySql) {
            Boolean bitValue = column.asBoolean();
            if (bitValue == null) {
              // setBoolean takes a primitive; a null value must be bound explicitly.
              preparedStatement.setNull(paramIndex, Types.BIT);
            } else {
              preparedStatement.setBoolean(paramIndex, bitValue);
            }
          } else {
            preparedStatement.setString(paramIndex, column.asString());
          }
          break;
        default:
          throw DataXException.asDataXException(DBUtilErrorCode.UNSUPPORTED_TYPE, String.format(
              "您的配置文件中的列配置信息有误. 因为DataX 不支持数据库写入这种字段类型. 字段名:[%s], 字段类型:[%d], 字段Java类型:[%s]. 请修改表中该字段的类型或者不同步该字段.",
              this.resultSetMetaData.getLeft().get(columnIndex),
              this.resultSetMetaData.getMiddle().get(columnIndex),
              this.resultSetMetaData.getRight().get(columnIndex)));
      }
      return preparedStatement;
    }

    /**
     * Rebuilds the write statement when {@link #calcValueHolder} has been customized.
     *
     * <p>If {@code calcValueHolder("")} returns the plain {@code VALUE_HOLDER}, the statement
     * derived earlier from the configured template is kept as is. Otherwise a value holder is
     * computed per column from its metadata type name, a new template is obtained from
     * {@code WriterUtil.getWriteTemplate}, and {@code writeRecordSql} is regenerated from it.
     */
    private void calcWriteRecordSql() {
      boolean usesDefaultHolder = VALUE_HOLDER.equals(calcValueHolder(""));
      if (usesDefaultHolder) {
        return;
      }

      List<String> valueHolders = new ArrayList<>(columnNumber);
      for (int idx = 0; idx < columns.size(); idx++) {
        String typeName = resultSetMetaData.getRight().get(idx);
        valueHolders.add(calcValueHolder(typeName));
      }

      // OB 1.0 handling: a MySql task on an OB 1.0 URL sets the forceUseUpdate flag.
      boolean forceUseUpdate = dataBaseType == DataBaseType.MySql
          && OriginalConfPretreatmentUtil.isOB10(jdbcUrl);

      INSERT_OR_REPLACE_TEMPLATE = WriterUtil
          .getWriteTemplate(columns, valueHolders, writeMode, dataBaseType, forceUseUpdate);
      writeRecordSql = String.format(INSERT_OR_REPLACE_TEMPLATE, this.table);
    }

    /**
     * Returns the SQL value placeholder for a column of the given type.
     *
     * <p>This base implementation ignores {@code columnType} and always returns
     * {@code VALUE_HOLDER}. The method is {@code protected}, and {@code calcWriteRecordSql()}
     * rebuilds the write statement whenever the result for {@code ""} differs from
     * {@code VALUE_HOLDER}.
     *
     * @param columnType type name of the target column
     * @return the placeholder to use for that column
     */
    protected String calcValueHolder(String columnType) {
      return VALUE_HOLDER;
    }
  }
}
